{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install langchain_google_vertexai langgraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Controlled generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Controlled generation using `ChatVertexAI` can be set up using the following arguments:\n",
    "\n",
    "- `response_schema`: Schema in OpenAPI format\n",
    "- `response_mime_type`: This argument must be set to `\"application/json\"`\n",
    "\n",
    "In the following example, we instruct the model to structure its response as an array of objects, each containing the properties `city_name` and `population`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_google_vertexai import ChatVertexAI\n",
    "\n",
    "model = ChatVertexAI(\n",
    "    model_name=\"gemini-1.5-flash-001\",\n",
    "    response_schema={\n",
    "        \"type\": \"array\",\n",
    "        \"items\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"city_name\": {\"type\": \"string\"},\n",
    "                \"population\": {\"type\": \"number\"},\n",
    "            },\n",
    "            \"required\": [\"city_name\", \"population\"],\n",
    "        },\n",
    "    },\n",
    "    response_mime_type=\"application/json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we invoke the model, we receive an `AIMessage` whose `content` property is a JSON string with the specified structure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='[{\"city_name\": \"Paris\", \"population\": 2140526}, {\"city_name\": \"Marseille\", \"population\": 861549}, {\"city_name\": \"Lyon\", \"population\": 516092}]\\n', additional_kwargs={}, response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE'}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE'}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE'}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE'}], 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 60, 'total_token_count': 77, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.10022381146748861, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-6be81dc5-868a-4f35-81cb-bb2c7b77fc29-0', usage_metadata={'input_tokens': 17, 'output_tokens': 60, 'total_tokens': 77})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.invoke(\"What are the 3 biggest cities in france\")"
   ]
  },
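  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of handling that string by hand, the cell below parses a sample `content` value with the standard-library `json` module and checks that each item carries the required keys. The sample string is taken from the output above (trailing newline dropped); the variable names `content`, `cities`, and `required_keys` are illustrative assumptions, and in practice you would read `model.invoke(...).content` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Sample value taken from the AIMessage output above (assumption: in real use,\n",
    "# read this from `model.invoke(...).content`).\n",
    "content = '[{\"city_name\": \"Paris\", \"population\": 2140526}, {\"city_name\": \"Marseille\", \"population\": 861549}, {\"city_name\": \"Lyon\", \"population\": 516092}]'\n",
    "\n",
    "# The JSON array becomes a Python list of dictionaries.\n",
    "cities = json.loads(content)\n",
    "\n",
    "# Check that every item has the keys the schema marks as required.\n",
    "required_keys = {\"city_name\", \"population\"}\n",
    "assert all(required_keys.issubset(city) for city in cities)\n",
    "\n",
    "[city[\"city_name\"] for city in cities]"
   ]
  },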
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, we can use a parser from `langchain_core.output_parsers` to build a chain that automatically parses the generated JSON string into Python objects (here, a list of dictionaries)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'city_name': 'Paris', 'population': 2141.96},\n",
       " {'city_name': 'Lyon', 'population': 516.09},\n",
       " {'city_name': 'Marseille', 'population': 870.72}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.output_parsers import SimpleJsonOutputParser\n",
    "\n",
    "parser = SimpleJsonOutputParser()\n",
    "\n",
    "chain = model | parser\n",
    "\n",
    "chain.invoke(\"What are the 3 biggest cities in france\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Controlled generation can also be achieved using Pydantic models rather than OpenAPI schemas. Instead of passing keyword arguments, we call the `with_structured_output` method with a Pydantic model. Reproducing the example above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "City(city_name='Paris', population=2140526)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pydantic import BaseModel, Field\n",
    "\n",
    "class City(BaseModel):\n",
    "    \"\"\"Represents a city with its population and name.\"\"\"\n",
    "    city_name: str = Field(description=\"Name of the city\")\n",
    "    population: int = Field(description=\"Total population of the city\")\n",
    "\n",
    "model = ChatVertexAI(model_name=\"gemini-1.5-flash-001\").with_structured_output(City, method=\"json_mode\")\n",
    "\n",
    "model.invoke(\"What are the 3 biggest cities in france\")"
   ]
  },
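  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, the model returns an instance of the Pydantic model (`City`) instead of a JSON string. Note that because `City` describes a single city, the result contains only one city even though the prompt asks for three."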
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
